import numpy as np
import torch
import os
from ASVBaselineTool.feature_tools import *
import soundfile as sf
import torchaudio
import torch.nn as nn


# Ensemble FGSM/PGD adversarial attack framework
class ENS():
    """Ensemble PGD-style adversarial attack against four anti-spoofing models.

    The four models score the same waveform through different front-ends:
    model1 = raw waveform (RawNet), model2 = MFCC, model3 = LFCC,
    model4 = spectrogram.  Gradients of each model's loss w.r.t. the input
    waveform are combined (RawNet weighted x200, see ``get_grad``) and used
    to craft adversarial audio that flips the genuine/fake decision.

    NOTE(review): all models are moved to CUDA unconditionally; running on a
    CPU-only machine will fail — confirm intended deployment environment.
    """

    def __init__(self, model1, model2, model3, model4, labelPath, datasetPath,
                 AdvsetPath, steps=80, eps=0.06):
        """Load the label protocol and prepare the four models.

        Args:
            model1..model4: RawNet / MFCC / LFCC / spectrogram models.
            labelPath: protocol file; each line is "<fname> ... <label>"
                where the last field is "genuine" or "fake".
            datasetPath: directory containing the source .wav files.
            AdvsetPath: directory where adversarial .wav files are written.
            steps: maximum number of PGD iterations per sample.
            eps: L-inf perturbation budget.
        """
        with open(labelPath, "r") as fh:  # renamed from `f`: the old name was
            lines = fh.readlines()        # shadowed by the loop variable below
        self.AFnames = [line.split(" ")[0] for line in lines]
        raw_labels = [line.split(" ")[-1].strip() for line in lines]
        # Map file name -> one-hot label tensor: genuine=[0,1], fake=[1,0].
        self.LabelMeta = {}
        for fname, lab in zip(self.AFnames, raw_labels):
            if lab == "genuine":
                lab = torch.Tensor([0.0, 1.0])
            elif lab == "fake":
                lab = torch.Tensor([1.0, 0.0])
            self.LabelMeta[fname] = lab

        self.models = [model1, model2, model3, model4]
        for m in self.models:
            m.cuda()
        # RawNet stays in train mode with its BatchNorm layers frozen
        # (see model0_to_eval); the other three models go fully to eval.
        self.model0_to_eval()
        self.models[1].eval()
        self.models[2].eval()
        self.models[3].eval()

        self.lossfun = nn.CrossEntropyLoss()
        self.datasetPath = datasetPath
        self.AdvSavePath = AdvsetPath  # output directory for adversarial wavs
        self.steps = steps
        self.eps = eps

    def get_grad(self, wav: torch.Tensor, targety):
        """Compute the ensemble waveform gradient for a targeted step.

        Each model gets its own leaf copy of the waveform so the four
        gradients stay separate; each gradient is normalized to unit
        L-inf norm before mixing.

        Args:
            wav: 1-D waveform tensor.
            targety: (1, 2) target one-hot label on CUDA.

        Returns:
            (grad, loss): combined waveform gradient (same shape as ``wav``)
            and the mean of the four per-model losses.
        """
        wav0, wav1, wav2, wav3 = (wav.detach().cpu() for _ in range(4))
        for leaf in (wav0, wav1, wav2, wav3):
            leaf.requires_grad = True
        # Per-model input features; gradients flow back through the
        # feature extractors to the waveform leaves.
        feats = [
            wav0.reshape(1, -1).cuda(),   # raw waveform for RawNet
            get_MFCC_40(wav1).cuda(),
            get_LFCC_70(wav2).cuda(),
            get_SPEC_2048(wav3).cuda(),
        ]
        losses = []
        for model, feat in zip(self.models, feats):
            loss = self.lossfun(model(feat), targety)
            loss.backward()
            losses.append(loss)
        grads = [leaf.grad.data for leaf in (wav0, wav1, wav2, wav3)]
        # Normalize each gradient to unit L-inf norm. The clamp guards the
        # division: an identically-zero gradient previously produced NaN.
        grads = [g / torch.max(torch.abs(g)).clamp(min=1e-12) for g in grads]
        # RawNet's waveform gradient is up-weighted x200 relative to the
        # three spectral models (empirical weighting kept from the original).
        grad = (200 * grads[0] + grads[1] + grads[2] + grads[3]) / 4
        loss = (losses[0] + losses[1] + losses[2] + losses[3]) / 4
        for model in self.models:
            model.zero_grad()
        return grad, loss

    def get_pre_label(self, wav: torch.Tensor):
        """Return the ensemble's predicted class index for ``wav``.

        Softmax outputs of the four models are summed and the argmax taken
        (index 0 = fake, 1 = genuine, per the label encoding in __init__).
        Runs under ``no_grad``: this is pure inference, so building autograd
        graphs (as the original code did) is unnecessary.
        """
        wav = wav.detach().cpu()
        with torch.no_grad():
            feats = [
                wav.reshape(1, -1).cuda(),
                get_MFCC_40(wav).cuda(),
                get_LFCC_70(wav).cuda(),
                get_SPEC_2048(wav).cuda(),
            ]
            softmax = torch.nn.Softmax(dim=1)
            pre_y = sum(softmax(model(feat))
                        for model, feat in zip(self.models, feats))
        return pre_y.max(1)[1].item()

    def model0_to_eval(self):
        """Put RawNet in train mode but freeze all its BatchNorm layers.

        Freezing BN keeps running statistics fixed during the attack while
        the rest of the network remains in train mode.
        """
        m = self.models[0]
        m.train()
        m.first_bn.eval()
        # NOTE(review): block0 has only bn2 frozen — presumably RawNet's
        # first residual block has no bn1; confirm against the model class.
        m.block0[0].bn2.eval()
        for block in (m.block1, m.block2, m.block3, m.block4, m.block5):
            block[0].bn1.eval()
            block[0].bn2.eval()
        m.bn_before_gru.eval()

    def Attack(self, alpha=0.0001, stop=0.1):
        """Run the targeted PGD attack over every file in the protocol.

        For each correctly-classified sample, iteratively step the waveform
        down the gradient of the flipped-label loss, projecting back into
        the eps ball and the valid audio range [-1, 1].  Successful
        adversarial samples are written to ``self.AdvSavePath``.

        Args:
            alpha: per-step perturbation size.
            stop: early-stop threshold on the mean ensemble loss.
        """
        print("attack alpha:" + str(alpha))
        num_f = len(self.AFnames)
        print("总样本：" + str(num_f))
        num_as = 0  # count of successful attacks
        for item in self.AFnames:
            src_wav, sr = sf.read(os.path.join(self.datasetPath, item))
            src_wav = torch.Tensor(src_wav)
            adv_wav = src_wav
            init_pred = self.get_pre_label(adv_wav)
            y = self.LabelMeta[item]
            # Targeted attack: the target is the flipped one-hot label.
            targety = (-1 * y + 1).cuda().reshape(1, 2)
            init_y = y.max(0, keepdim=True)[1].item()
            # Skip samples the ensemble already misclassifies — an "attack"
            # on them would be meaningless.
            if init_pred != init_y:
                print("原始分类错误 pass")
                continue
            # Core PGD loop.
            perturbed_wav = adv_wav.detach().clone().cpu()
            for _ in range(self.steps):
                perturbed_wav = perturbed_wav.detach().cpu()
                grad, loss = self.get_grad(wav=perturbed_wav, targety=targety)
                sign_data_grad = grad.sign()
                # Descend the target-class loss (minimize, hence minus).
                perturbed_wav = perturbed_wav.detach() - alpha * sign_data_grad
                # Project: perturbation within [-eps, eps], audio in [-1, 1].
                delta = torch.clamp(perturbed_wav - adv_wav,
                                    min=-self.eps, max=self.eps)
                perturbed_wav = torch.clamp(adv_wav + delta, min=-1, max=1)
                if loss.item() < stop:
                    break
            # Re-evaluate the ensemble on the perturbed audio.
            perturbed_wav = perturbed_wav.detach().cpu()
            temp_init_pred = self.get_pre_label(perturbed_wav)
            if temp_init_pred == init_y:
                print("攻击失败 pass")
                continue
            sp = os.path.join(self.AdvSavePath, item)
            print("攻击成功并保存值路径:" + sp)
            num_as += 1
            sf.write(sp, perturbed_wav.detach().cpu().numpy(), sr)
        print("总样本：" + str(num_f))
        print("攻击成功样本数量:" + str(num_as))
        # NOTE: the denominator includes samples skipped for being
        # misclassified from the start (kept from the original code).
        print("攻击成功率:" + str(num_as / num_f))




















